diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index 0955a95f..46b3959c 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -440,14 +440,18 @@ def _dataset_concat(
     # we've already verified everything is consistent; now, calculate
     # shared dimension sizes so we can expand the necessary variables
     def ensure_common_dims(vars):
         # ensure each variable with the given name shares the same
         # dimensions and the same shape for all of them except along the
         # concat dimension
-        common_dims = tuple(pd.unique([d for v in vars for d in v.dims]))
+        # collect the dimensions of *every* variable, de-duplicated in
+        # first-seen order (dict.fromkeys keeps insertion order)
+        common_dims = tuple(dict.fromkeys(d for v in vars for d in v.dims))
         if dim not in common_dims:
             common_dims = (dim,) + common_dims
         for var, dim_len in zip(vars, concat_dim_lengths):
+            # compare ordered tuples, not sets: a variable whose dims are
+            # merely permuted must still go through set_dims
             if var.dims != common_dims:
                 common_shape = tuple(dims_sizes.get(d, dim_len) for d in common_dims)
                 var = var.set_dims(common_dims, common_shape)
             yield var
